package com.sc.security.service;

import com.sc.security.service.sms.ValidateCodeRepo;
import com.sc.security.validate.code.ValidateCodeException;
import com.sc.security.validate.code.sms.ValidateCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.social.connect.web.HttpSessionSessionStrategy;
import org.springframework.social.connect.web.SessionStrategy;
import org.springframework.web.bind.ServletRequestBindingException;
import org.springframework.web.bind.ServletRequestUtils;
import org.springframework.web.context.request.ServletWebRequest;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

/**
 * Template implementation of {@link ValidateCodeProcessor}.
 *
 * <p>{@link #create} generates a code, stores it through {@link ValidateCodeRepo} and hands it to
 * the subclass's {@link #send}. {@link #validate} compares the code submitted in the request with
 * the stored one and removes the stored code once it is consumed or expired.
 *
 * <p>Two different type keys are used:
 * <ul>
 *   <li>the generator bean is looked up by the request-URI segment after {@code /code/} plus the
 *       suffix {@code CodeGenDefault};</li>
 *   <li>storage and validation use the {@link ValidateCodeType} derived from the subclass's simple
 *       name with {@code CodeProcessor} stripped and upper-cased (so subclass names must follow the
 *       {@code <Type>CodeProcessor} pattern and match a {@code ValidateCodeType} constant).</li>
 * </ul>
 *
 * @param <C> concrete code type produced by the generator and passed to {@link #send}
 * @author sc
 * Created on  2018/1/17
 */
@Slf4j
public abstract class AbstractValidateCodeProcessor<C extends ValidateCode> implements ValidateCodeProcessor {

    /** Suffix appended to the URI-derived type to form the generator bean name. */
    private static final String GENERATOR_BEAN_SUFFIX = "CodeGenDefault";

    /**
     * Every {@code ValidateCodeGenerator} implementation in the application context, keyed by
     * bean name.
     */
    @Autowired
    private Map<String, ValidateCodeGenerator> validateCodeGenerators;

    /** Storage backend for issued codes. */
    @Autowired
    private ValidateCodeRepo validateCodeRepo;

    /**
     * Generates a new code for this request, stores it, then sends it.
     *
     * @param request current web request; its URI selects the generator
     * @throws IOException propagated from {@link #send}
     * @throws ServletRequestBindingException propagated from {@link #send}
     * @throws ValidateCodeException if no generator is registered for the requested type
     */
    @Override
    public void create(ServletWebRequest request) throws IOException, ServletRequestBindingException {
        C validateCode = generate(request);
        save(request, validateCode);
        send(request, validateCode);
    }

    /**
     * Delivers the generated code for this request. The delivery mechanism is defined by the
     * subclass.
     *
     * @param request      current web request
     * @param validateCode the freshly generated (and already stored) code
     */
    protected abstract void send(ServletWebRequest request, C validateCode) throws IOException, ServletRequestBindingException;

    /**
     * Generates a code using the generator bean named {@code <uriType>CodeGenDefault}.
     *
     * @param request current web request
     * @return the generated code
     * @throws ValidateCodeException if no generator bean with that name exists
     */
    private C generate(ServletWebRequest request) {
        String beanName = getProcessorType(request) + GENERATOR_BEAN_SUFFIX;
        ValidateCodeGenerator validateCodeGenerator = validateCodeGenerators.get(beanName);
        if (validateCodeGenerator == null) {
            // Fail with a domain exception instead of an NPE when the URI names an unknown type.
            throw new ValidateCodeException("验证码生成器" + beanName + "不存在");
        }
        log.info("生成验证码");
        // Unchecked: assumes the generator registered for this type produces C — not verifiable here.
        @SuppressWarnings("unchecked")
        C code = (C) validateCodeGenerator.generate(request);
        return code;
    }

    /**
     * Stores the code through the repository under this processor's {@link ValidateCodeType}.
     * Only the code value and expiry are copied into a plain {@link ValidateCode} before storing.
     *
     * @param request      current web request
     * @param validateCode the generated code
     */
    private void save(ServletWebRequest request, C validateCode) {
        // Presumably copies to the base type so subclass-specific payload is not persisted — confirm.
        ValidateCode validate = new ValidateCode(validateCode.getCode(), validateCode.getExpire());
        validateCodeRepo.save(request, validate, getValidateCodeType());
        // The code value is a credential and must not be written to logs.
        log.info("验证码保存 key:{}", "SESSION_KEY_CODE_" + getProcessorType(request).toUpperCase(Locale.ROOT));
    }

    /**
     * Returns the part of the request URI after {@code /code/}. The result is an empty string when
     * the URI does not contain {@code /code/}.
     */
    private String getProcessorType(ServletWebRequest request) {
        return StringUtils.substringAfter(request.getRequest().getRequestURI(), "/code/");
    }

    /**
     * Validates the code submitted in the request against the stored one.
     *
     * <p>The stored code is removed when it has expired or after a successful match. It is kept
     * on a mismatch.
     *
     * @param request current web request carrying the submitted code
     * @throws ValidateCodeException if the parameter cannot be read, is blank, no code is stored,
     *                               the stored code has expired, or the values differ
     */
    @Override
    public void validate(ServletWebRequest request) {
        ValidateCodeType processorType = getValidateCodeType();
        // Unchecked: the repository returns the base type; this processor only stores Cs for its type.
        @SuppressWarnings("unchecked")
        C codeInSession = (C) validateCodeRepo.get(request, processorType);

        String codeInRequest;
        try {
            codeInRequest = ServletRequestUtils.getStringParameter(request.getRequest(),
                    processorType.getParamNameOnValidate());
        } catch (ServletRequestBindingException e) {
            throw new ValidateCodeException("获取验证码的值失败");
        }

        if (StringUtils.isBlank(codeInRequest)) {
            throw new ValidateCodeException(processorType + "验证码的值不能为空");
        }

        if (codeInSession == null) {
            throw new ValidateCodeException(processorType + "验证码不存在");
        }

        if (codeInSession.isExpired()) {
            validateCodeRepo.remove(request, processorType);
            throw new ValidateCodeException(processorType + "验证码已过期");
        }

        if (!StringUtils.equals(codeInSession.getCode(), codeInRequest)) {
            throw new ValidateCodeException(processorType + "验证码不匹配");
        }

        // Successful match: consume the code so it cannot be reused.
        validateCodeRepo.remove(request, processorType);
    }

    /**
     * Derives this processor's {@link ValidateCodeType} from the subclass's simple name.
     * For example, a subclass named {@code ImageCodeProcessor} maps to {@code IMAGE}.
     * {@link Locale#ROOT} keeps the upper-casing independent of the JVM default locale.
     */
    private ValidateCodeType getValidateCodeType() {
        String type = StringUtils.substringBefore(getClass().getSimpleName(), "CodeProcessor");
        return ValidateCodeType.valueOf(type.toUpperCase(Locale.ROOT));
    }
}
